{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data locations\n",
    "char_vocab_path = \"./data/char_vocabs.txt\"  # character vocabulary, one entry per line\n",
    "train_data_path = \"./data/train_data\"  # training corpus\n",
    "test_data_path = \"./data/test_data\"  # test corpus\n",
    "\n",
    "# Reserved entries that are placed in front of the vocabulary read from disk\n",
    "special_words = ['<PAD>', '<UNK>']\n",
    "\n",
    "# BIO tags in id order: the position in this list is the tag id\n",
    "label2idx = {tag: tag_id for tag_id, tag in enumerate(\n",
    "    [\"O\", \"B-PER\", \"I-PER\", \"B-LOC\", \"I-LOC\", \"B-ORG\", \"I-ORG\"]\n",
    ")}\n",
    "# Reverse lookup: tag id -> BIO tag\n",
    "idx2label = {tag_id: tag for tag, tag_id in label2idx.items()}\n",
    "\n",
    "# Load the character vocabulary; the special entries take ids 0 and 1\n",
    "with open(char_vocab_path, \"r\", encoding=\"utf8\") as vocab_file:\n",
    "    char_vocabs = special_words + [entry.strip() for entry in vocab_file]\n",
    "\n",
    "# Two-way mapping between characters and their ids\n",
    "idx2vocab = dict(enumerate(char_vocabs))\n",
    "vocab2idx = {token: token_id for token_id, token in enumerate(char_vocabs)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read a character-level corpus and encode it as id sequences\n",
    "def read_corpus(corpus_path, vocab2idx, label2idx):\n",
    "    \"\"\"Parse a corpus file into parallel lists of character ids and label ids.\n",
    "\n",
    "    Each non-blank line must hold a character and its label separated by\n",
    "    whitespace; blank (or whitespace-only) lines separate sentences.\n",
    "\n",
    "    Args:\n",
    "        corpus_path: path of the UTF-8 encoded corpus file.\n",
    "        vocab2idx: mapping from character to id. Characters missing from it\n",
    "            are encoded as vocab2idx['<UNK>'].\n",
    "        label2idx: mapping from label to id. Labels missing from it are\n",
    "            encoded as 0.\n",
    "\n",
    "    Returns:\n",
    "        (datas, labels): two lists of equal length holding, per sentence, the\n",
    "        list of character ids and the list of label ids. Empty sentences\n",
    "        (e.g. produced by consecutive blank lines) are not emitted.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if a non-blank line does not split into exactly two fields.\n",
    "    \"\"\"\n",
    "    datas, labels = [], []\n",
    "    sent_, tag_ = [], []\n",
    "\n",
    "    def flush_sentence():\n",
    "        # Encode and store the buffered sentence; an empty buffer stores nothing\n",
    "        if not sent_:\n",
    "            return\n",
    "        datas.append([vocab2idx[char] if char in vocab2idx else vocab2idx['<UNK>'] for char in sent_])\n",
    "        labels.append([label2idx.get(label, 0) for label in tag_])\n",
    "        sent_.clear()\n",
    "        tag_.clear()\n",
    "\n",
    "    with open(corpus_path, encoding='utf-8') as fr:\n",
    "        # Stream the file instead of loading every line into memory first\n",
    "        for line_no, line in enumerate(fr, 1):\n",
    "            fields = line.split()\n",
    "            if not fields:\n",
    "                # Sentence boundary (also covers lines containing only whitespace)\n",
    "                flush_sentence()\n",
    "                continue\n",
    "            if len(fields) != 2:\n",
    "                raise ValueError('{}:{}: expected two fields (char label), got {!r}'.format(\n",
    "                    corpus_path, line_no, line))\n",
    "            char, label = fields\n",
    "            sent_.append(char)\n",
    "            tag_.append(label)\n",
    "    # Keep the final sentence even when the file does not end with a blank line\n",
    "    flush_sentence()\n",
    "    return datas, labels\n",
    "\n",
    "# Encode the training corpus and the test corpus with the same vocabulary and label maps\n",
    "train_datas, train_labels = read_corpus(\n",
    "    corpus_path=train_data_path, vocab2idx=vocab2idx, label2idx=label2idx)\n",
    "test_datas, test_labels = read_corpus(\n",
    "    corpus_path=test_data_path, vocab2idx=vocab2idx, label2idx=label2idx)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2158, 347, 2572, 850, 679, 5930, 2308, 6084, 563, 3765, 181, 6256, 4897, 563, 3765, 5025, 406, 3903, 915, 4110, 6802, 327, 240, 315, 2683, 352, 651, 847, 6802, 3990, 629, 3648, 341, 651, 3542, 871, 4033, 4211, 3903, 4214, 3900, 6802, 6007, 3433, 6312, 5108, 5361, 2473, 775, 181, 1208, 3007, 571, 2984, 4144, 651, 3542, 3556, 182]\n",
      "['我', '们', '是', '受', '到', '郑', '振', '铎', '先', '生', '、', '阿', '英', '先', '生', '著', '作', '的', '启', '示', '，', '从', '个', '人', '条', '件', '出', '发', '，', '瞄', '准', '现', '代', '出', '版', '史', '研', '究', '的', '空', '白', '，', '重', '点', '集', '藏', '解', '放', '区', '、', '国', '民', '党', '毁', '禁', '出', '版', '物', '。']\n",
      "[0, 0, 0, 0, 0, 1, 2, 2, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 6, 6, 0, 0, 0, 0, 0, 0]\n",
      "['O', 'O', 'O', 'O', 'O', 'B-PER', 'I-PER', 'I-PER', 'O', 'O', 'O', 'B-PER', 'I-PER', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'O']\n"
     ]
    }
   ],
   "source": [
    "# Show one training sample: character ids, decoded characters, label ids, decoded labels\n",
    "sample_ids, sample_tag_ids = train_datas[5], train_labels[5]\n",
    "print(sample_ids)\n",
    "print([idx2vocab[char_id] for char_id in sample_ids])\n",
    "print(sample_tag_ids)\n",
    "print([idx2label[tag_id] for tag_id in sample_tag_ids])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6874 7\n",
      "padding sequences\n",
      "x_train shape: (50658, 100)\n",
      "x_test shape: (4631, 100)\n",
      "trainlabels shape: (50658, 100, 7)\n",
      "testlabels shape: (4631, 100, 7)\n",
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "input_1 (InputLayer)         (None, 100)               0         \n",
      "_________________________________________________________________\n",
      "masking_1 (Masking)          (None, 100)               0         \n",
      "_________________________________________________________________\n",
      "embedding_1 (Embedding)      (None, 100, 128)          879872    \n",
      "_________________________________________________________________\n",
      "bidirectional_1 (Bidirection (None, 100, 128)          98816     \n",
      "_________________________________________________________________\n",
      "time_distributed_1 (TimeDist (None, 100, 7)            903       \n",
      "_________________________________________________________________\n",
      "crf_1 (CRF)                  (None, 100, 7)            119       \n",
      "=================================================================\n",
      "Total params: 979,710\n",
      "Trainable params: 979,710\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n",
      "Train on 45592 samples, validate on 5066 samples\n",
      "Epoch 1/20\n",
      " 1184/45592 [..............................] - ETA: 4:04 - loss: 4.9087 - crf_viterbi_accuracy: 0.7343"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m<ipython-input-4-e0e5d3351a65>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m()\u001b[0m\n\u001b[0;32m     44\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     45\u001b[0m \u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcompile\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mloss\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mcrf_loss\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;34m'adam'\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmetrics\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mcrf_viterbi_accuracy\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 46\u001b[1;33m \u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtrain_datas\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtrain_labels\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mepochs\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mEPOCHS\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mverbose\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mvalidation_split\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;36m0.1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     47\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     48\u001b[0m \u001b[0mscore\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mevaluate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtest_datas\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtest_labels\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mBATCH_SIZE\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\Anaconda3\\lib\\site-packages\\keras\\engine\\training.py\u001b[0m in \u001b[0;36mfit\u001b[1;34m(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)\u001b[0m\n\u001b[0;32m   1037\u001b[0m                                         \u001b[0minitial_epoch\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0minitial_epoch\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1038\u001b[0m                                         \u001b[0msteps_per_epoch\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0msteps_per_epoch\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1039\u001b[1;33m                                         validation_steps=validation_steps)\n\u001b[0m\u001b[0;32m   1040\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1041\u001b[0m     def evaluate(self, x=None, y=None,\n",
      "\u001b[1;32m~\\Anaconda3\\lib\\site-packages\\keras\\engine\\training_arrays.py\u001b[0m in \u001b[0;36mfit_loop\u001b[1;34m(model, f, ins, out_labels, batch_size, epochs, verbose, callbacks, val_f, val_ins, shuffle, callback_metrics, initial_epoch, steps_per_epoch, validation_steps)\u001b[0m\n\u001b[0;32m    197\u001b[0m                     \u001b[0mins_batch\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m]\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mins_batch\u001b[0m\u001b[1;33m[\u001b[0m\u001b[0mi\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtoarray\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    198\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 199\u001b[1;33m                 \u001b[0mouts\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mf\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mins_batch\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    200\u001b[0m                 \u001b[0mouts\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mto_list\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mouts\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    201\u001b[0m                 \u001b[1;32mfor\u001b[0m \u001b[0ml\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mo\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mout_labels\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mouts\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\Anaconda3\\lib\\site-packages\\keras\\backend\\tensorflow_backend.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, inputs)\u001b[0m\n\u001b[0;32m   2713\u001b[0m                 \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_legacy_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   2714\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 2715\u001b[1;33m             \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   2716\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   2717\u001b[0m             \u001b[1;32mif\u001b[0m \u001b[0mpy_any\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mis_tensor\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[1;32min\u001b[0m \u001b[0minputs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\Anaconda3\\lib\\site-packages\\keras\\backend\\tensorflow_backend.py\u001b[0m in \u001b[0;36m_call\u001b[1;34m(self, inputs)\u001b[0m\n\u001b[0;32m   2673\u001b[0m             \u001b[0mfetched\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_callable_fn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0marray_vals\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mrun_metadata\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   2674\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 2675\u001b[1;33m             \u001b[0mfetched\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_callable_fn\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0marray_vals\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m   2676\u001b[0m         \u001b[1;32mreturn\u001b[0m \u001b[0mfetched\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m:\u001b[0m\u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0moutputs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   2677\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\Anaconda3\\lib\\site-packages\\tensorflow\\python\\client\\session.py\u001b[0m in \u001b[0;36m__call__\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1437\u001b[0m           ret = tf_session.TF_SessionRunCallable(\n\u001b[0;32m   1438\u001b[0m               \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_session\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_session\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_handle\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mstatus\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m-> 1439\u001b[1;33m               run_metadata_ptr)\n\u001b[0m\u001b[0;32m   1440\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m   1441\u001b[0m           \u001b[0mproto_data\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import keras\n",
    "from keras import backend as K\n",
    "from keras.layers import Masking, Embedding, Bidirectional, LSTM, Dense, Input, TimeDistributed, Activation\n",
    "from keras.models import Model\n",
    "from keras.models import Sequential\n",
    "from keras.preprocessing import sequence\n",
    "from keras_contrib.layers import CRF\n",
    "from keras_contrib.losses import crf_loss\n",
    "from keras_contrib.metrics import crf_viterbi_accuracy\n",
    "K.clear_session()\n",
    "\n",
    "# Hyper-parameters and paths\n",
    "EPOCHS = 20\n",
    "BATCH_SIZE = 64\n",
    "EMBED_DIM = 128\n",
    "HIDDEN_SIZE = 64\n",
    "MAX_LEN = 100\n",
    "VOCAB_SIZE = len(vocab2idx)\n",
    "CLASS_NUMS = len(label2idx)\n",
    "MODEL_PATH = \"./model/ch_ner_model.h5\"\n",
    "print(VOCAB_SIZE, CLASS_NUMS)\n",
    "\n",
    "print('padding sequences')\n",
    "# The padded arrays get their own names (x_*/y_*) so the raw id lists stay intact\n",
    "# and re-running this cell does not pad or one-hot encode already processed data.\n",
    "# pad_sequences is used with its defaults (padding/truncating at the front, value 0).\n",
    "x_train = sequence.pad_sequences(train_datas, maxlen=MAX_LEN)\n",
    "y_train = sequence.pad_sequences(train_labels, maxlen=MAX_LEN)\n",
    "x_test = sequence.pad_sequences(test_datas, maxlen=MAX_LEN)\n",
    "y_test = sequence.pad_sequences(test_labels, maxlen=MAX_LEN)\n",
    "print('x_train shape:', x_train.shape)\n",
    "print('x_test shape:', x_test.shape)\n",
    "\n",
    "# One-hot encode the label ids\n",
    "y_train = keras.utils.to_categorical(y_train, CLASS_NUMS)\n",
    "y_test = keras.utils.to_categorical(y_test, CLASS_NUMS)\n",
    "print('trainlabels shape:', y_train.shape)\n",
    "print('testlabels shape:', y_test.shape)\n",
    "\n",
    "## Build the BiLSTM+CRF model\n",
    "inputs = Input(shape=(MAX_LEN,), dtype='int32')\n",
    "# Embedding(mask_zero=True) derives the per-timestep padding mask from the ids itself,\n",
    "# so no separate Masking layer is applied to the raw integer input.\n",
    "x = Embedding(VOCAB_SIZE, EMBED_DIM, mask_zero=True)(inputs)\n",
    "x = Bidirectional(LSTM(HIDDEN_SIZE, return_sequences=True))(x)\n",
    "x = TimeDistributed(Dense(CLASS_NUMS))(x)\n",
    "outputs = CRF(CLASS_NUMS)(x)\n",
    "model = Model(inputs=inputs, outputs=outputs)\n",
    "model.summary()\n",
    "\n",
    "model.compile(loss=crf_loss, optimizer='adam', metrics=[crf_viterbi_accuracy])\n",
    "# Train with the configured batch size (the same one used for evaluation below)\n",
    "model.fit(x_train, y_train, batch_size=BATCH_SIZE, epochs=EPOCHS, verbose=1, validation_split=0.1)\n",
    "\n",
    "score = model.evaluate(x_test, y_test, batch_size=BATCH_SIZE)\n",
    "print(model.metrics_names)\n",
    "print(score)\n",
    "\n",
    "# save model (create the target directory first so saving cannot fail on a fresh checkout)\n",
    "os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)\n",
    "model.save(MODEL_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{0: 'O', 1: 'B-PER', 2: 'I-PER', 3: 'B-LOC', 4: 'I-LOC', 5: 'B-ORG', 6: 'I-ORG'}\n",
      "['中', '华', '人', '民', '共', '和', '国', '国', '务', '院', '总', '理', '周', '恩', '来', '在', '外', '交', '部', '长', '陈', '毅', '的', '陪', '同', '下', '，', '连', '续', '访', '问', '了', '埃', '塞', '俄', '比', '亚', '等', '非', '洲', '1', '0', '国', '以', '及', '阿', '尔', '巴', '尼', '亚', '。']\n",
      "[242, 787, 315, 3007, 584, 966, 1208, 1208, 720, 6275, 2008, 3682, 950, 2026, 2684, 1219, 1361, 301, 5940, 6195, 6263, 2986, 3903, 6279, 889, 218, 6802, 5804, 4495, 5426, 6209, 283, 1284, 1322, 452, 2994, 296, 4277, 6362, 3148, 15, 14, 1208, 343, 843, 6256, 1639, 1778, 1654, 296, 182]\n",
      "['B-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'B-PER', 'I-PER', 'I-PER', 'O', 'B-ORG', 'I-ORG', 'I-ORG', 'O', 'B-PER', 'I-PER', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'I-LOC', 'I-LOC', 'I-LOC', 'I-LOC', 'O', 'B-LOC', 'I-LOC', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'I-LOC', 'I-LOC', 'I-LOC', 'I-LOC', 'O']\n"
     ]
    }
   ],
   "source": [
    "from keras.models import load_model\n",
    "import numpy as np\n",
    "\n",
    "maxlen = 100\n",
    "sentence = \"中华人民共和国国务院总理周恩来在外交部长陈毅的陪同下，连续访问了埃塞俄比亚等非洲10国以及阿尔巴尼亚。\"\n",
    "\n",
    "# Keep at most maxlen characters so that characters and predicted labels stay aligned\n",
    "sent_chars = list(sentence)[:maxlen]\n",
    "sent2id = [vocab2idx[word] if word in vocab2idx else vocab2idx['<UNK>'] for word in sent_chars]\n",
    "\n",
    "# Left-pad with id 0 up to the fixed input length; the batch holds this single sentence\n",
    "sent2id_new = np.array([[0] * (maxlen - len(sent2id)) + sent2id])\n",
    "y_pred = model.predict(sent2id_new)\n",
    "# Most likely label id per position for the only sentence in the batch\n",
    "y_label = np.argmax(y_pred, axis=2)[0]\n",
    "# Drop the labels of the padding positions. Slicing from an explicit start index\n",
    "# (instead of a negative length) also yields an empty list for an empty sentence.\n",
    "y_ner = [idx2label[i] for i in y_label[maxlen - len(sent_chars):]]\n",
    "\n",
    "print(idx2label)\n",
    "print(sent_chars)\n",
    "print(sent2id)\n",
    "print(y_ner)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "中华人民共和国国务院 ORG\n",
      "周恩来 PER\n",
      "外交部 ORG\n",
      "陈毅 PER\n",
      "埃塞俄比亚 LOC\n",
      "非洲 LOC\n",
      "阿尔巴尼亚 LOC\n"
     ]
    }
   ],
   "source": [
    "# Parse the predicted tags and extract the named entities\n",
    "def get_valid_nertag(input_data, result_tags):\n",
    "    \"\"\"Extract named entities from a BIO-tagged sequence.\n",
    "\n",
    "    Args:\n",
    "        input_data: sliceable sequence of tokens aligned with result_tags.\n",
    "        result_tags: BIO tags such as \"B-PER\", \"I-PER\" or \"O\".\n",
    "\n",
    "    Returns:\n",
    "        List of (input_data[start:end], entity_type) tuples in order of\n",
    "        appearance. An entity starts at a \"B-\" tag and is extended only by\n",
    "        directly following \"I-\" tags of the same type. Any other tag closes\n",
    "        the open entity; an \"I-\" tag without a matching open entity is ignored.\n",
    "    \"\"\"\n",
    "    result_words = []\n",
    "    start, end = 0, 0  # span [start, end) of the entity that is currently open\n",
    "    tag_label = \"O\"  # type of the open entity, \"O\" when none is open\n",
    "    for i, tag in enumerate(result_tags):\n",
    "        # partition never raises, even for a tag without a dash\n",
    "        prefix, _, label = tag.partition(\"-\")\n",
    "        if prefix == \"I\" and tag_label != \"O\" and label == tag_label:\n",
    "            end = i + 1  # extend the open entity up to the current position\n",
    "            continue\n",
    "        # Every other tag ends the open entity right before position i\n",
    "        if tag_label != \"O\":\n",
    "            result_words.append((input_data[start: end], tag_label))\n",
    "            tag_label = \"O\"\n",
    "        if prefix == \"B\" and label:\n",
    "            tag_label = label  # open a new entity at the current position\n",
    "            start, end = i, i + 1\n",
    "    if tag_label != \"O\":  # an entity is still open at the end of the sequence\n",
    "        result_words.append((input_data[start: end], tag_label))\n",
    "    return result_words\n",
    "\n",
    "# Extract the entities of the predicted sentence and print each one followed by its type\n",
    "result_words = get_valid_nertag(sent_chars, y_ner)\n",
    "for entity_chars, entity_type in result_words:\n",
    "    print(\"\".join(entity_chars), entity_type)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{0: 'O', 1: 'B-PER', 2: 'I-PER', 3: 'B-LOC', 4: 'I-LOC', 5: 'B-ORG', 6: 'I-ORG'}\n",
      "['中', '华', '人', '民', '共', '和', '国', '国', '务', '院', '总', '理', '周', '恩', '来', '在', '外', '交', '部', '长', '陈', '毅', '的', '陪', '同', '下', '，', '连', '续', '访', '问', '了', '埃', '塞', '俄', '比', '亚', '等', '非', '洲', '1', '0', '国', '以', '及', '阿', '尔', '巴', '尼', '亚', '。']\n",
      "[242, 787, 315, 3007, 584, 966, 1208, 1208, 720, 6275, 2008, 3682, 950, 2026, 2684, 1219, 1361, 301, 5940, 6195, 6263, 2986, 3903, 6279, 889, 218, 6802, 5804, 4495, 5426, 6209, 283, 1284, 1322, 452, 2994, 296, 4277, 6362, 3148, 15, 14, 1208, 343, 843, 6256, 1639, 1778, 1654, 296]\n",
      "['B-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'B-PER', 'I-PER', 'I-PER', 'O', 'B-ORG', 'I-ORG', 'I-ORG', 'O', 'B-PER', 'I-PER', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'I-LOC', 'I-LOC', 'I-LOC', 'I-LOC', 'O', 'B-LOC', 'I-LOC', 'O', 'O', 'O', 'O', 'O', 'B-LOC', 'I-LOC', 'I-LOC', 'I-LOC', 'I-LOC']\n"
     ]
    }
   ],
   "source": [
    "# Stand-alone inference from the saved model: import what this cell needs so it\n",
    "# also runs on a fresh kernel without executing the training cell first.\n",
    "import numpy as np\n",
    "from keras.models import load_model\n",
    "from keras_contrib.layers import CRF\n",
    "\n",
    "char_vocab_path = \"./data/char_vocabs.txt\" # vocabulary file\n",
    "model_path = \"./model/ch_ner_model.h5\" # saved model file\n",
    "\n",
    "ner_labels = {\"O\": 0, \"B-PER\": 1, \"I-PER\": 2, \"B-LOC\": 3, \"I-LOC\": 4, \"B-ORG\": 5, \"I-ORG\": 6}\n",
    "special_words = ['<PAD>', '<UNK>']\n",
    "MAX_LEN = 100\n",
    "\n",
    "# Rebuild the character vocabulary; the special entries take ids 0 and 1\n",
    "with open(char_vocab_path, \"r\", encoding=\"utf8\") as fo:\n",
    "    char_vocabs = [line.strip() for line in fo]\n",
    "char_vocabs = special_words + char_vocabs\n",
    "\n",
    "idx2vocab = {idx: char for idx, char in enumerate(char_vocabs)}\n",
    "vocab2idx = {char: idx for idx, char in idx2vocab.items()}\n",
    "\n",
    "idx2label = {idx: label for label, idx in ner_labels.items()}\n",
    "\n",
    "sentence = \"中华人民共和国国务院总理周恩来在外交部长陈毅的陪同下，连续访问了埃塞俄比亚等非洲10国以及阿尔巴尼亚\"\n",
    "\n",
    "# Derive the characters from this cell's sentence (not from a variable left over by\n",
    "# another cell) and keep at most MAX_LEN of them so characters and labels stay aligned\n",
    "sent_chars = list(sentence)[:MAX_LEN]\n",
    "sent2id = [vocab2idx[word] if word in vocab2idx else vocab2idx['<UNK>'] for word in sent_chars]\n",
    "\n",
    "# Left-pad with id 0 up to the fixed input length; the batch holds this single sentence\n",
    "sent2input = np.array([[0] * (MAX_LEN - len(sent2id)) + sent2id])\n",
    "\n",
    "# The CRF layer is passed via custom_objects; compile=False loads the model without compiling it\n",
    "model = load_model(model_path, custom_objects={'CRF': CRF}, compile=False)\n",
    "y_pred = model.predict(sent2input)\n",
    "\n",
    "# Most likely label id per position for the only sentence in the batch\n",
    "y_label = np.argmax(y_pred, axis=2)[0]\n",
    "# Drop the labels of the padding positions. Slicing from an explicit start index\n",
    "# (instead of a negative length) also yields an empty list for an empty sentence.\n",
    "y_ner = [idx2label[i] for i in y_label[MAX_LEN - len(sent_chars):]]\n",
    "\n",
    "print(idx2label)\n",
    "print(sent_chars)\n",
    "print(sent2id)\n",
    "print(y_ner)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
